package main

import (
	"database/sql"
	"errors"
	"fmt"
	"io/ioutil"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	_ "github.com/go-sql-driver/mysql"
	"github.com/panjf2000/ants/v2"
)

// ObConfig holds the settings OpenDB uses to build the MySQL-driver DSN
// (username@tenant_name:password@tcp(host:port)/default_db) and to size the
// database/sql connection pool.
type ObConfig struct {
	username    string // login user, written before "@tenant_name" in the DSN
	tenant_name string // tenant, written after "username@" in the DSN
	password    string // never printed by String
	host        string
	port        int

	default_db     string // database path component of the DSN
	max_open_conns int    // passed to (*sql.DB).SetMaxOpenConns
	max_idle_conns int    // passed to (*sql.DB).SetMaxIdleConns
}

// String implements fmt.Stringer so that printing a config with %v or %s
// keeps the default struct layout but never reveals the password: a
// non-empty password is rendered as "***".
func (c ObConfig) String() string {
	password := ""
	if c.password != "" {
		password = "***"
	}
	return fmt.Sprintf("{%s %s %s %s %d %s %d %d}",
		c.username, c.tenant_name, password, c.host, c.port,
		c.default_db, c.max_open_conns, c.max_idle_conns)
}

// OpenDB opens a *sql.DB for config with the "mysql" driver, using a DSN of
// the form username@tenant_name:password@tcp(host:port)/default_db, and then
// configures the pool: idle connections are closed after 24h, any
// connection after 648h, and the open/idle limits come from config.
//
// The DSN is echoed to stdout with the password replaced by "***" so
// credentials do not end up in console logs.
func OpenDB(config ObConfig) (*sql.DB, error) {
	fmt.Printf("ob url: %s\n", obDSN(config, "***"))

	db, err := sql.Open("mysql", obDSN(config, config.password))
	if err != nil {
		return nil, err
	}

	db.SetConnMaxIdleTime(24 * time.Hour)
	db.SetConnMaxLifetime(648 * time.Hour)
	db.SetMaxOpenConns(config.max_open_conns)
	db.SetMaxIdleConns(config.max_idle_conns)

	return db, nil
}

// obDSN formats the driver DSN for config, placing password in the password
// slot so one layout serves both the real and the masked (printable) form.
func obDSN(config ObConfig, password string) string {
	return fmt.Sprintf("%s@%s:%s@tcp(%s:%d)/%s",
		config.username, config.tenant_name, password,
		config.host, config.port, config.default_db)
}

// ExecQuery runs query on db and prints the result set to stdout as a table:
// the column names separated by "\t|", a dashed separator line, then one line
// per row. Column values that arrive as []byte are printed as strings.
//
// It returns the first error from running the query, reading the column
// list, scanning a row, or iterating the result set (errors are wrapped with
// the failing step). Rows printed before an error stay printed. The rows are
// always closed before returning.
func ExecQuery(db *sql.DB, query string) error {
	rows, err := db.Query(query)
	if err != nil {
		return fmt.Errorf("db.Query: %w", err)
	}
	// Close on every return path, including the Columns/Scan failures.
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return fmt.Errorf("rows.Columns: %w", err)
	}

	count := len(columns)
	if count > 0 {
		for _, name := range columns {
			fmt.Printf("%s\t|", name)
		}
		fmt.Printf("\n-----------------------------------------------------------------------------------------------\n")
	}

	// Scan targets are allocated once; every Scan overwrites all slots and
	// each row is printed before the next Scan.
	values := make([]interface{}, count)
	valuePtrs := make([]interface{}, count)
	for i := range values {
		valuePtrs[i] = &values[i]
	}

	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return fmt.Errorf("rows.Scan: %w", err)
		}
		for _, val := range values {
			if b, ok := val.([]byte); ok {
				val = string(b)
			}
			fmt.Printf(" %v\t|", val)
		}
		fmt.Println()
	}
	// rows.Next returns false both at end of data and on failure; only
	// rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("rows iteration: %w", err)
	}
	return nil
}

// initParameters builds the run configuration from the command-line
// arguments (args[0] is the program name):
//
//	args[1] user      args[2] tenant   args[3] password
//	args[4] host      args[5] port     args[6] default database
//	args[7] first query number (optional, default 1)
//	args[8] async flag, parsed by strconv.ParseBool (optional, default false)
//
// The connection settings replace the built-in defaults only when
// args[1]..args[6] are all present. The arguments are echoed to stdout with
// args[3] masked. It returns the config, the first query number, the async
// flag, and a non-nil error if the port, query number, or async flag is
// invalid; parse failures wrap the underlying strconv error.
func initParameters(args []string) (ObConfig, int, bool, error) {
	argNum := len(args)

	// Echo a copy with the password slot masked so credentials stay out of
	// console output.
	shown := append([]string(nil), args...)
	if len(shown) > 3 {
		shown[3] = "***"
	}
	fmt.Printf("%d %v\n", argNum, shown)
	if argNum > 1 {
		for i, v := range shown {
			fmt.Printf("key:%d  value:%s\n", i, v)
		}
	}

	// Defaults targeting a server on 127.0.0.1.
	config := ObConfig{
		username:    "root",
		tenant_name: "obmysql",
		password:    "12",
		host:        "127.0.0.1",
		port:        2881,

		default_db:     "oceanbase",
		max_open_conns: 50,
		max_idle_conns: 10,
	}

	if argNum > 6 {
		port, err := strconv.Atoi(args[5])
		if err != nil {
			return config, 1, false, fmt.Errorf("Please input a valid server port: %w", err)
		}
		if port < 1 || port > 65535 {
			return config, 1, false, errors.New("Please input a valid server port")
		}

		config = ObConfig{
			username:    args[1],
			tenant_name: args[2],
			password:    args[3],
			host:        args[4],
			port:        port,

			default_db:     args[6],
			max_open_conns: 50,
			max_idle_conns: 10,
		}
	}

	startNum := 1
	if argNum > 7 {
		num, err := strconv.Atoi(args[7])
		if err != nil {
			return config, startNum, false, fmt.Errorf("Please input a valid Query start number: %w", err)
		}
		startNum = num
	}

	async := false
	if argNum > 8 {
		val, err := strconv.ParseBool(args[8])
		if err != nil {
			return config, startNum, false, fmt.Errorf("Please input a valid Async flag (bool): %w", err)
		}
		async = val
	}

	return config, startNum, async, nil
}

// main is the command-line entry point; it exits with the status returned by
// runQueries so that deferred cleanup inside runQueries always runs first.
func main() {
	os.Exit(runQueries())
}

// runQueries parses os.Args with initParameters, opens the database with
// OpenDB, and runs exec_all for every query number from the configured start
// through 22. Queries run one after another, or, when the async flag is set,
// concurrently on an ants pool of 10 goroutines. It returns the process exit
// code: 1 if arguments, the database handle, the pool, or a task submission
// fails, otherwise 0.
func runQueries() int {
	config, startNum, async, err := initParameters(os.Args)
	if err != nil {
		fmt.Fprintf(os.Stderr, "invalid arguments: %v\n", err)
		return 1
	}
	fmt.Printf("config: %v, start_num: %d, async: %t\n", config, startNum, async)

	const lastQuery = 22 // highest query number that is run
	cost := true         // print each query's elapsed time

	db, err := OpenDB(config)
	if err != nil {
		fmt.Fprintf(os.Stderr, "open db config %v error: %v\n", config, err)
		return 1
	}
	defer db.Close()

	if !async {
		for i := startNum; i <= lastQuery; i++ {
			exec_all(db, i, cost)
		}
		return 0
	}

	pool, err := ants.NewPool(10)
	if err != nil {
		fmt.Fprintf(os.Stderr, "create goroutine pool error: %v\n", err)
		return 1
	}
	defer pool.Release()

	status := 0
	var wg sync.WaitGroup
	for i := startNum; i <= lastQuery; i++ {
		num := i // per-iteration copy for the closure (required before Go 1.22)
		wg.Add(1)
		err := pool.Submit(func() {
			// Deferred so wg.Wait still returns if exec_all panics inside
			// the pool worker.
			defer wg.Done()
			exec_all(db, num, cost)
		})
		if err != nil {
			wg.Done() // the task never started, so it will not call Done
			fmt.Fprintf(os.Stderr, "submit Q%d error: %v\n", num, err)
			status = 1
		}
	}
	wg.Wait()
	return status
}

// exec_all reads the SQL for query number i via readSQL, echoes it to stdout,
// and runs it via exec with the label "Q[i]". Query 15's text is split on ";"
// and each piece runs separately, in order, labelled "Q[15][j]"; exec skips
// blank pieces.
//
// If readSQL fails, the error is reported on stderr and the query is
// skipped. If exec returns an error, exec_all panics with the query label
// and the error.
func exec_all(db *sql.DB, i int, cost bool) {
	query, err := readSQL(i)
	if err != nil {
		fmt.Fprintf(os.Stderr, "read Q %d err: %v\n", i, err)
		return
	}

	if i == 15 {
		// Query 15's file holds several ';'-separated statements that are run
		// one at a time (presumably because the DSN does not enable
		// multi-statement execution — confirm).
		for j, stmt := range strings.Split(query, ";") {
			fmt.Printf("Q%d[%d]:%s\n", i, j, stmt)
			if err := exec(db, stmt, fmt.Sprintf("Q[%d][%d]", i, j), cost); err != nil {
				panic(fmt.Sprintf("ExecQuery Q %d[%d] err: %v\n", i, j, err))
			}
		}
		return
	}

	fmt.Printf("Q%d:%s\n", i, query)
	if err := exec(db, query, fmt.Sprintf("Q[%d]", i), cost); err != nil {
		panic(fmt.Sprintf("ExecQuery Q %d err: %v\n", i, err))
	}
}

/* func asyncRun() {
	defer ants.Release()
	// Use the common pool.
	var wg sync.WaitGroup

	// Use the pool with a function,
	// set 10 to the capacity of goroutine pool and 1 second for expired duration.
	p, _ := ants.NewPoolWithFunc(10, func(i interface{}) {
		myFunc(i)
		wg.Done()
	})
	defer p.Release()
} */

// exec runs query on db through ExecQuery and returns ExecQuery's error.
// Text that is empty or whitespace-only (for example the remainder after a
// trailing ";" in a split file) is not sent to the server: a notice is
// printed and nil is returned. When cost is true, "<label> cost <N> ms" is
// printed with the elapsed time of ExecQuery, whether or not it failed.
func exec(db *sql.DB, query, label string, cost bool) error {
	// TrimSpace also catches "\r\n", tabs and spaces, which the previous
	// newline-only check let through to the server.
	if strings.TrimSpace(query) == "" {
		fmt.Printf("sql is  nil\n")
		return nil
	}
	start := time.Now()
	err := ExecQuery(db, query)
	if cost {
		fmt.Printf("%s cost %d ms\n", label, time.Since(start).Milliseconds())
	}
	return err
}

// readSQL returns the contents of queries/db<num>.sql (resolved against the
// current working directory), printing which file it is reading beforehand.
// If the file cannot be read it returns "" together with the read error.
func readSQL(num int) (string, error) {
	path := fmt.Sprintf("queries/db%d.sql", num)
	fmt.Printf("Read Q%d From %s\n", num, path)

	content, err := ioutil.ReadFile(path)
	if err == nil {
		return string(content), nil
	}
	return "", err
}
